#***************************************************************************************/
#
#    Based on Pointnet2 Library (MIT license):
#    https://github.com/sshaoshuai/Pointnet2.PyTorch
#
#    Copyright (c) 2019 Shaoshuai Shi

#    Permission is hereby granted, free of charge, to any person obtaining a copy
#    of this software and associated documentation files (the "Software"), to deal
#    in the Software without restriction, including without limitation the rights
#    to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
#    copies of the Software, and to permit persons to whom the Software is
#    furnished to do so, subject to the following conditions:

#    The above copyright notice and this permission notice shall be included in all
#    copies or substantial portions of the Software.

#    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
#    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
#    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
#    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
#    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
#    OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
#    SOFTWARE.
#
#***************************************************************************************/

import torch.nn as nn
from typing import List, Tuple


class SharedMLP(nn.Sequential):
    """Chain of ``Conv2d`` layers whose channel counts follow ``args``.

    Layer ``k`` maps ``args[k]`` channels to ``args[k + 1]`` channels using
    this module's ``Conv2d`` helper with default (1, 1) kernel, and is
    registered as ``name + 'layer<k>'``.

    Args:
        args: channel counts; ``len(args) - 1`` layers are built.
        bn: enable batch norm in every layer (see ``first``).
        activation: activation module passed to every layer (see ``first``).
        preact: build the layers in pre-activation order.
        first: together with ``preact``, strips batch norm and activation
            from the very first layer only.
        name: prefix for the registered layer names.
        instance_norm: forwarded to every layer.
    """

    def __init__(
        self,
        args: List[int],
        *,
        bn: bool = False,
        activation=nn.ReLU(inplace=True),
        preact: bool = False,
        first: bool = False,
        name: str = "",
        instance_norm: bool = False,
    ):
        super().__init__()

        for idx, (c_in, c_out) in enumerate(zip(args[:-1], args[1:])):
            # Only the leading layer of a "first" pre-activation stack is
            # left without normalization and activation.
            bare = first and preact and idx == 0
            layer = Conv2d(c_in,
                           c_out,
                           bn=(not bare) and bn,
                           activation=None if bare else activation,
                           preact=preact,
                           instance_norm=instance_norm)
            self.add_module(name + f'layer{idx}', layer)


class _ConvBase(nn.Sequential):
    """Convolution wrapped with optional normalization and activation.

    Children are registered (and therefore executed) in this order, each
    name prefixed with ``name`` and each optional stage present only when
    enabled:

    * ``preact=False``: ``conv``, ``bn``, ``activation``, ``in``
    * ``preact=True``:  ``bn``, ``activation``, ``in``, ``conv``

    Instance norm is only added when ``bn`` is falsy; if both are requested,
    batch norm is used and instance norm is skipped.

    Args:
        in_size: input channels of the convolution.
        out_size: output channels of the convolution.
        kernel_size: forwarded to ``conv``.
        stride: forwarded to ``conv``.
        padding: forwarded to ``conv``.
        activation: module registered after the batch norm, or ``None`` for
            no activation. The given instance is registered as-is.
        bn: enable batch normalization built by ``batch_norm``.
        init: in-place initializer applied to the conv weight, or ``None``
            to keep the conv's own default initialization.
        conv: convolution class, called as ``conv(in_size, out_size,
            kernel_size=..., stride=..., padding=..., bias=...)``.
        batch_norm: factory called with the normalized channel count.
        bias: request a conv bias. It is disabled whenever ``bn`` is set and,
            when kept, initialized to zero.
        preact: place normalization/activation before the conv; the norms
            then act on ``in_size`` channels instead of ``out_size``.
        name: prefix for the child module names.
        instance_norm: enable instance norm built by ``instance_norm_func``
            (ignored when ``bn`` is set).
        instance_norm_func: instance-norm class, called with the channel
            count, ``affine=False`` and ``track_running_stats=False``.
    """

    def __init__(self,
                 in_size,
                 out_size,
                 kernel_size,
                 stride,
                 padding,
                 activation,
                 bn,
                 init,
                 conv=None,
                 batch_norm=None,
                 bias=True,
                 preact=False,
                 name="",
                 instance_norm=False,
                 instance_norm_func=None):
        super().__init__()

        # The conv bias is disabled whenever batch norm is enabled.
        bias = bias and (not bn)
        conv_unit = conv(in_size,
                         out_size,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding,
                         bias=bias)
        # ``init=None`` keeps the conv's default init (same contract as FC).
        if init is not None:
            init(conv_unit.weight)
        if bias:
            nn.init.constant_(conv_unit.bias, 0)

        # Pre-activation normalizes the conv input, otherwise its output.
        norm_size = in_size if preact else out_size
        bn_unit = batch_norm(norm_size) if bn else None
        in_unit = None
        if instance_norm and not bn:
            in_unit = instance_norm_func(norm_size,
                                         affine=False,
                                         track_running_stats=False)

        def add_norm_and_activation():
            # NOTE: instance norm is registered *after* the activation.
            # This order is kept as-is; changing it changes the network.
            if bn_unit is not None:
                self.add_module(name + 'bn', bn_unit)
            if activation is not None:
                self.add_module(name + 'activation', activation)
            if in_unit is not None:
                self.add_module(name + 'in', in_unit)

        if preact:
            add_norm_and_activation()
        self.add_module(name + 'conv', conv_unit)
        if not preact:
            add_norm_and_activation()


class _BNBase(nn.Sequential):
    """Sequential holding a single batch-norm module named ``name + "bn"``.

    The norm's weight is set to 1 and its bias to 0 after construction.

    Args:
        in_size: channel count passed to ``batch_norm``.
        batch_norm: batch-norm class/factory called with ``in_size``.
        name: prefix for the child module name.
    """

    def __init__(self, in_size, batch_norm=None, name=""):
        super().__init__()
        norm = batch_norm(in_size)
        self.add_module(name + "bn", norm)

        # Explicit affine reset: unit scale, zero shift.
        nn.init.constant_(norm.weight, 1.0)
        nn.init.constant_(norm.bias, 0)


class BatchNorm1d(_BNBase):
    """``_BNBase`` specialised to ``nn.BatchNorm1d``.

    Args:
        in_size: number of channels to normalize.
        name: prefix for the child module name (keyword-only).
    """

    def __init__(self, in_size: int, *, name: str = ""):
        super().__init__(in_size, nn.BatchNorm1d, name)


class BatchNorm2d(_BNBase):
    """``_BNBase`` specialised to ``nn.BatchNorm2d``.

    Args:
        in_size: number of channels to normalize.
        name: prefix for the child module name.
    """

    def __init__(self, in_size: int, name: str = ""):
        super().__init__(in_size, nn.BatchNorm2d, name)


class Conv1d(_ConvBase):
    """``_ConvBase`` built from ``nn.Conv1d``, ``BatchNorm1d`` and
    ``nn.InstanceNorm1d``.

    All options after ``out_size`` are keyword-only; defaults give a
    kernel-size-1 convolution followed by an in-place ReLU, with the weight
    initialized by ``nn.init.kaiming_normal_``.
    """

    def __init__(self,
                 in_size: int,
                 out_size: int,
                 *,
                 kernel_size: int = 1,
                 stride: int = 1,
                 padding: int = 0,
                 activation=nn.ReLU(inplace=True),
                 bn: bool = False,
                 init=nn.init.kaiming_normal_,
                 bias: bool = True,
                 preact: bool = False,
                 name: str = "",
                 instance_norm=False):
        super().__init__(
            in_size=in_size,
            out_size=out_size,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            activation=activation,
            bn=bn,
            init=init,
            conv=nn.Conv1d,
            batch_norm=BatchNorm1d,
            bias=bias,
            preact=preact,
            name=name,
            instance_norm=instance_norm,
            instance_norm_func=nn.InstanceNorm1d,
        )


class Conv2d(_ConvBase):
    """``_ConvBase`` built from ``nn.Conv2d``, ``BatchNorm2d`` and
    ``nn.InstanceNorm2d``.

    All options after ``out_size`` are keyword-only; defaults give a (1, 1)
    convolution followed by an in-place ReLU, with the weight initialized by
    ``nn.init.kaiming_normal_``.
    """

    def __init__(self,
                 in_size: int,
                 out_size: int,
                 *,
                 kernel_size: Tuple[int, int] = (1, 1),
                 stride: Tuple[int, int] = (1, 1),
                 padding: Tuple[int, int] = (0, 0),
                 activation=nn.ReLU(inplace=True),
                 bn: bool = False,
                 init=nn.init.kaiming_normal_,
                 bias: bool = True,
                 preact: bool = False,
                 name: str = "",
                 instance_norm=False):
        super().__init__(
            in_size=in_size,
            out_size=out_size,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            activation=activation,
            bn=bn,
            init=init,
            conv=nn.Conv2d,
            batch_norm=BatchNorm2d,
            bias=bias,
            preact=preact,
            name=name,
            instance_norm=instance_norm,
            instance_norm_func=nn.InstanceNorm2d,
        )


class FC(nn.Sequential):
    """Fully connected layer with optional batch norm and activation.

    Children are registered (and executed) in this order, each name prefixed
    with ``name`` and each optional stage present only when enabled:

    * ``preact=False``: ``fc``, ``bn`` (on ``out_size``), ``activation``
    * ``preact=True``:  ``bn`` (on ``in_size``), ``activation``, ``fc``

    Args:
        in_size: input features of the linear layer.
        out_size: output features of the linear layer.
        activation: module to register, or ``None`` for no activation.
        bn: enable this module's ``BatchNorm1d``; also disables the linear
            bias.
        init: optional in-place initializer applied to the linear weight;
            ``None`` keeps ``nn.Linear``'s default initialization.
        preact: place batch norm/activation before the linear layer.
        name: prefix for the child module names.
    """

    def __init__(self,
                 in_size: int,
                 out_size: int,
                 *,
                 activation=nn.ReLU(inplace=True),
                 bn: bool = False,
                 init=None,
                 preact: bool = False,
                 name: str = ""):
        super().__init__()

        fc = nn.Linear(in_size, out_size, bias=not bn)
        if init is not None:
            init(fc.weight)
        if not bn:
            # In-place ``constant_``; the non-underscore ``nn.init.constant``
            # is deprecated and warns on every call.
            nn.init.constant_(fc.bias, 0)

        if preact:
            if bn:
                self.add_module(name + 'bn', BatchNorm1d(in_size))

            if activation is not None:
                self.add_module(name + 'activation', activation)

        self.add_module(name + 'fc', fc)

        if not preact:
            if bn:
                self.add_module(name + 'bn', BatchNorm1d(out_size))

            if activation is not None:
                self.add_module(name + 'activation', activation)
